
Experimental Redux conditioning for Flux LoRA training #1838

Draft · recris wants to merge 2 commits into base: sd3

Conversation

recris commented Dec 15, 2024

This PR adds support for training Flux.1 LoRA using conditioning from the Redux image encoder.

Instead of relying on text captions to condition the model, why not use the image itself to provide a "perfect" caption?

Redux+SigLIP produce a T5-compatible embedding that, when used as conditioning, generates images very close to the target. I thought this could be used instead of text descriptions that may or may not match the concepts as understood by the base model.
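
For readers unfamiliar with Redux, here is a minimal conceptual sketch of the image-conditioning path. This is not code from this PR; the projector key names, layer sizes and SiLU activation are assumptions about the released Redux checkpoint, and only the SigLIP model ID comes from the description above.

```python
# Conceptual sketch only: the redux_up/redux_down key names, layer sizes and the
# SiLU activation are assumptions about the released Redux checkpoint.
import torch
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import SiglipImageProcessor, SiglipVisionModel

SIGLIP_ID = "google/siglip-so400m-patch14-384"
siglip = SiglipVisionModel.from_pretrained(SIGLIP_ID)
processor = SiglipImageProcessor.from_pretrained(SIGLIP_ID)

def redux_tokens(pil_image, redux_path="flux1-redux-dev.safetensors"):
    # SigLIP so400m-384 yields a 27x27 grid of patch tokens (729 tokens, 1152-dim each).
    pixels = processor(images=pil_image, return_tensors="pt").pixel_values
    with torch.no_grad():
        feats = siglip(pixel_values=pixels).last_hidden_state      # (1, 729, 1152)

    # The Redux model is a small projector mapping SigLIP tokens into the 4096-dim
    # T5-XXL embedding space that Flux uses for text conditioning.
    sd = {k: v.float() for k, v in load_file(redux_path).items()}  # key names assumed
    x = F.linear(feats, sd["redux_up.weight"], sd["redux_up.bias"])
    x = F.linear(F.silu(x), sd["redux_down.weight"], sd["redux_down.bias"])
    return x                                                        # (1, 729, 4096)
```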

To use this I've added the following new parameters (a rough sketch of how they are applied follows the list):

  • redux_model_path: Safetensors file for the Redux model (downloadable from here)
    • Note: the code will also pull the SigLIP model from HuggingFace (google/siglip-so400m-patch14-384)
  • vision_cond_downsample: this controls downsampling of the Redux tokens. By default, Redux conditioning uses a 27x27 grid of tokens, which is a lot and has a very strong effect that prevents proper learning. Setting this parameter to N downsamples the tokens to an NxN grid, reducing the effect. (Disabled by default.)
  • vision_cond_dropout: probability of dropout for the vision conditioning. During a training step this will randomly choose to ignore the vision conditioning and use the text conditioning instead. For example, 0.2 means it will use Redux 80% of the time and regular captions the other 20%.
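
A conceptual sketch (not the PR's actual code) of how these two parameters interact during a training step, assuming the Redux tokens arrive as a 27x27 grid; the function and variable names are placeholders:

```python
# Conceptual sketch of the vision_cond_downsample / vision_cond_dropout behaviour.
# `t5_tokens` are the encoded caption tokens, `redux_tokens` the Redux image tokens.
import random
import torch
import torch.nn.functional as F

def build_conditioning(t5_tokens, redux_tokens,
                       vision_cond_downsample=5, vision_cond_dropout=0.2):
    # Dropout: with this probability, skip the vision conditioning entirely for the
    # step and train on plain text, so the LoRA does not become Redux-dependent.
    if random.random() < vision_cond_dropout:
        return t5_tokens

    if vision_cond_downsample:
        b, n, c = redux_tokens.shape                      # n = 729 for a 27x27 grid
        side = int(n ** 0.5)
        grid = redux_tokens.transpose(1, 2).reshape(b, c, side, side)
        grid = F.adaptive_avg_pool2d(grid, vision_cond_downsample)
        redux_tokens = grid.flatten(2).transpose(1, 2)    # (b, N*N, c)

    # As described in the notes below, text and (downsampled) Redux tokens are
    # currently used together by concatenating along the sequence dimension.
    return torch.cat([t5_tokens, redux_tokens], dim=1)
```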

Experimental Notes:

  • Redux is extremely good at describing a target image, to the point where a LoRA trained solely with it becomes very weak when used without Redux. Because the conditioning is so good, it lowers the average loss significantly and the resulting LoRA learns a lot less - it essentially learns the "difference" between the base model + Redux and the training data. To mitigate this I added the dropout parameter, so that during training the model also sees normal text prompts and does not become dependent on Redux for inference.
  • The conditioning from the vision encoder is very strong; when using vision_cond_ratio I usually have to set it to 0.2 or lower before I start seeing meaningful differences in what gets learned.
  • Using vision_cond_dropout = 0.5 seems to work well enough; I noticed an improvement in the end result, with fewer "broken" images (bad anatomy, etc.) during inference.
  • This might be a good option for training styles, given that this use case tends to require higher-quality, more complete captions.
  • Using this with full fine-tuning is not supported, but there is no technical restriction preventing it; I just don't have the hardware to test it.
  • This is not a replacement for text captions; the changes only affect the T5 conditioning, and CLIP still needs text captions as before.
  • The interpolation method behind vision_cond_ratio feels very crude and unsound to me; maybe there is a better approach?
    • Update: this has been replaced by a down-sampling method, which seems to work better. vision_cond_downsample = 5 is a good place to start. Note that training now uses both text and Redux tokens simultaneously.

I don't expect this PR to be merged anytime soon; I had to make some sub-optimal code changes to get this working. I am posting it for visibility so that people can play with it and give feedback.

recris marked this pull request as draft December 15, 2024 21:32
@FurkanGozukara

@recris amazing work

did you notice whether this solves the issue of training multiple concepts of the same class?

like 2 men at the same time

or when you train a man, it makes all other men turn into you.

does this solve that problem?

moreover, after training, you don't need to use Redux, right, with vision_cond_dropout = 0.5 + vision_cond_ratio = 0.2?

recris commented Dec 15, 2024

> @recris amazing work
>
> did you notice whether this solves the issue of training multiple concepts of the same class?
>
> like 2 men at the same time
>
> or when you train a man, it makes all other men turn into you.
>
> does this solve that problem?
>
> moreover, after training, you don't need to use Redux, right, with vision_cond_dropout = 0.5 + vision_cond_ratio = 0.2?

This has nothing to do with either of those issues. For multiple concepts you would need something like pivotal tuning, which is currently not supported either.

This PR is only an attempt to improve overall quality in the presence of poorly captioned training data.

@FurkanGozukara

@recris thanks, but you still recommend vision_cond_dropout = 0.5 + vision_cond_ratio = 0.2, and then we can use the trained LoRA without Flux Redux, right?

recris commented Dec 15, 2024

Please read the notes fully before posting. These are not "recommendations"; this has hardly been tested in a comprehensive way and is probably not ready for widespread use.

That said, you can probably start with vision_cond_dropout = 0.5, vision_cond_ratio = 1.0. Beware that this could also require changes to the learning rate or the total number of training steps to achieve the same results as before.

dxqbYD commented Dec 19, 2024

Interesting concept!
About the LoRA only learning the difference between (base model + image conditioning) and training data:

If you consider this a downside and want a stand-alone LoRA as output, you could try to (gradually?) remove the image conditioning from the model prediction, but still expect the model to learn to make the same prediction as if it were still conditioned. Similar to this concept:

Nerogar/OneTrainer#505

but using the base model conditioned by Redux as the teacher, rather than the plain base model.
The result could be a LoRA that replicates the Redux conditioning without needing Redux, which could then be improved further by regular training on data.
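
A minimal sketch of this distillation idea, assuming placeholder model wrappers and argument names (this is not the OneTrainer implementation): the frozen base model conditioned with text + Redux acts as the teacher, while the LoRA-augmented student only sees the text conditioning that will be available at inference.

```python
# Minimal sketch of Redux-as-teacher distillation; model call signatures and names
# here are placeholders, not the actual OneTrainer or sd-scripts APIs.
import torch
import torch.nn.functional as F

def distillation_step(flux_with_lora, flux_base, noisy_latents, timesteps,
                      text_cond, redux_cond, optimizer):
    with torch.no_grad():
        # Teacher: frozen base model prediction under text + Redux conditioning.
        teacher_pred = flux_base(noisy_latents, timesteps,
                                 cond=torch.cat([text_cond, redux_cond], dim=1))

    # Student: LoRA-augmented model sees only the text conditioning, exactly as it
    # will at inference time, so there is no train/inference mismatch.
    student_pred = flux_with_lora(noisy_latents, timesteps, cond=text_cond)

    loss = F.mse_loss(student_pred, teacher_pred)  # match the Redux-conditioned teacher
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```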

recris commented Dec 20, 2024

> If you consider this a downside and want a stand-alone LoRA as output, you could try to (gradually?) remove the image conditioning from the model prediction, but still expect the model to learn to make the same prediction as if it were still conditioned

This is what vision_cond_dropout can be used for: you feed the model a mix of caption-conditioned and Redux-conditioned samples so it learns not to become dependent on Redux. A value of at least 0.5 seems to do the trick, but maybe you can go even lower.

recris commented Dec 28, 2024

After some additional testing, in my (subjective) experience I can set vision_cond_dropout as low as 0.25 and still get good results.

I still don't think mashing together text and vision conditioning is the right approach. The correct way to do this would be to concatenate the embeddings and then use a non-binary attention mask to control the influence of the Redux tokens, but that would require deeper changes to the Flux model code.
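
For illustration, one way to realize a non-binary mask is to fold a per-token weight into the attention logits as an additive log-weight bias on the Redux key positions. The sketch below is a plain attention function with that bias; it is not code from this PR or from the Flux implementation.

```python
# Sketch of a "soft" attention mask: Redux keys are damped by a weight in (0, 1]
# instead of being kept or dropped outright. Not code from this PR.
import math
import torch.nn.functional as F

def attention_with_soft_mask(q, k, v, n_redux, redux_weight=0.2):
    # q, k, v: (batch, heads, seq, head_dim); the last n_redux keys are Redux tokens.
    bias = q.new_zeros(k.shape[-2])
    bias[-n_redux:] = math.log(redux_weight)              # log-weight acts as a soft mask
    logits = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    attn = F.softmax(logits + bias, dim=-1)
    return attn @ v
```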

dxqbYD commented Jan 3, 2025

> If you consider this a downside and want a stand-alone LoRA as output, you could try to (gradually?) remove the image conditioning from the model prediction, but still expect the model to learn to make the same prediction as if it were still conditioned

> This is what vision_cond_dropout can be used for: you feed the model a mix of caption-conditioned and Redux-conditioned samples so it learns not to become dependent on Redux. A value of at least 0.5 seems to do the trick, but maybe you can go even lower.

If I understand your code correctly, the dropout removes the Redux conditioning, so in a dropout step you are training normally against the training images.
Is that really helpful for making the resulting LoRA work well without Redux? You are basically optimizing towards the training images under two conditionings independently of each other, with and without Redux, but then during inference Redux is not available.

I have now experimented a bit with this idea myself. On the left is the (only) training image; the middle and right images are samples from the LoRA, without Redux:
[image: training image (left) and two LoRA samples generated without Redux (middle, right)]
Only a few minutes of training each, 100-200 steps. The LoRA was trained against a Redux-conditioned teacher prediction but was itself not Redux-conditioned, so there is no mismatch between training and inference.

Here is some experimental code: Nerogar/OneTrainer@master...dxqbYD:OneTrainer:redux

recris commented Jan 3, 2025

I am still running some tests; I had a configuration mistake which might have affected my earlier results.

With Redux dropout my expectation was that it would "diversify" the captions seen during training, improving learning robustness; the concept here is similar to providing multiple captions for the same image with varying levels of detail and selecting one at random in each training step (for reference: #1643)

dxqbYD commented Jan 3, 2025

> I am still running some tests; I had a configuration mistake which might have affected my earlier results.

> With Redux dropout my expectation was that it would "diversify" the captions seen during training, improving learning robustness; the concept here is similar to providing multiple captions for the same image with varying levels of detail and selecting one at random in each training step (for reference: #1643)

I get the idea, but what I am saying is:

you are doing "multi-caption training", where
a) one "caption" is an actual text caption and
b) the other is an image embedding.

And then you are done with training, and during inference you can only use a). You cannot access what the model has learned for b) because you don't have the embedding.

That's why I'm proposing that there should be only one optimization target: the one you still have during inference. That's what I've done in my samples.

The central line in my code is this:
https://github.com/dxqbYD/OneTrainer/blob/7d0e8878316d477f00d46d512ad39c3d750f7e42/modules/trainer/GenericTrainer.py#L700

wcde commented Jan 5, 2025

@dxqbYD I think what you're talking about is only valid for unknown things. For example, if we're training a specific person, then we naturally want to give the model his name manually. But for known concepts, for example training cars to be less broken, we don't need to manually tell the model that a car is a car; SigLIP will do it for us perfectly.
In my tests, Redux really reduces the efficiency of training exact concepts, but at the same time some general details start to look better.

dxqbYD commented Jan 6, 2025

> @dxqbYD I think what you're talking about is only valid for unknown things. For example, if we're training a specific person, then we naturally want to give the model his name manually. But for known concepts, for example training cars to be less broken, we don't need to manually tell the model that a car is a car; SigLIP will do it for us perfectly. In my tests, Redux really reduces the efficiency of training exact concepts, but at the same time some general details start to look better.

sorry, I don't understand how any of this is related to what I wrote before.

recris commented Jan 6, 2025

I've changed my previous approach to controlling the strength of the Redux conditioning:

  • Now it performs downsampling of the Redux tokens to an N-by-N grid, controlled by the vision_cond_downsample parameter.
  • Text tokens are always used during training and the Redux tokens are concatenated to them. The dropout feature still works as before.

The number of Redux tokens in the conditioning seems to affect the ability to make the LoRA usable with text prompts; from testing various downsampling sizes I've noticed that beyond N=5 (25 tokens) it starts to perform noticeably worse (unless counteracted with dropout).

Also, training with both text and Redux tokens has a negative performance impact due to the number of tokens being used, but a significant amount of that is padding added by default. I recommend lowering t5xxl_max_token_length if you are not using very long captions.
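
To put rough numbers on the sequence-length impact, here is a quick back-of-the-envelope comparison; the 512-token padded text length used below is an assumption, since the actual default depends on the configuration.

```python
# Rough token-budget comparison; all conditioning tokens join Flux's joint attention
# together with the image tokens, so every extra token adds cost.
redux_full = 27 * 27            # 729 Redux tokens without downsampling
redux_ds = 5 * 5                # 25 tokens with vision_cond_downsample = 5

for text_len in (512, 128):     # 512: assumed default padding; 128: lowered for short captions
    for redux_len in (redux_full, redux_ds):
        print(f"text={text_len} + redux={redux_len} -> {text_len + redux_len} conditioning tokens")
```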
